import os
import warnings

import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from sklearn.metrics import cohen_kappa_score
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
pd.set_option('future.no_silent_downcasting', True)
pd.options.mode.chained_assignment = None 

# Import and organize the dataset. Each split of the loaded DatasetDict is
# converted to pandas exactly once; the resulting frames are reused below
# instead of converting every split a second time.
ds = load_dataset('mlburnham/Pol_NLI')
splits = ['train', 'validation', 'test']
train, val, test = (ds[split].to_pandas() for split in splits)

# All splits stacked in train/validation/test order. `all_datasets` keeps each
# split's own index labels (not renumbered); `alldata` holds the same rows
# with a fresh 0..n-1 index.
all_datasets = pd.concat([train, val, test])
alldata = all_datasets.reset_index(drop=True)

def describe_dataset(data_dict, split):
    """Summarize one split of a dataset dictionary.

    Parameters
    ----------
    data_dict : mapping
        Mapping (e.g. a ``DatasetDict``) whose values provide ``to_pandas()``.
        The resulting frame must have the columns ``dataset``, ``premise``,
        ``hypothesis``, ``augmented_hypothesis`` and ``entailment``.
    split : str
        Key of the split to summarize.

    Returns
    -------
    dict
        ``unique_datasets``, ``unique_premises``, ``total_premises`` (number of
        rows), ``unique_hypotheses``, ``unique_augmented_hypotheses``,
        ``avg_premise_word_count`` / ``med_premise_word_count`` (mean / median
        number of whitespace-separated tokens per premise) and
        ``entailment_counts`` (mapping label value -> number of rows).

    Notes
    -----
    Every premise is assumed to be a string; a missing premise (``None``)
    makes ``.split()`` raise ``AttributeError``.
    """
    df = data_dict[split].to_pandas()  # Convert the split to a pandas DataFrame

    # Word count per premise, computed once and reused for mean and median.
    premise_word_counts = df['premise'].apply(lambda text: len(text.split()))

    return {
        'unique_datasets': df['dataset'].nunique(),
        'unique_premises': df['premise'].nunique(),
        'total_premises': len(df['premise']),
        'unique_hypotheses': df['hypothesis'].nunique(),
        'unique_augmented_hypotheses': df['augmented_hypothesis'].nunique(),
        'avg_premise_word_count': premise_word_counts.mean(),
        'med_premise_word_count': premise_word_counts.median(),
        'entailment_counts': df['entailment'].value_counts().to_dict(),
    }


###########
## Table 1
###########
# Build one Dataset per task from the combined (train + validation + test)
# frame; each task becomes one row of Table 1.
stance = Dataset.from_pandas(all_datasets[all_datasets['task'] == 'stance detection'])
event = Dataset.from_pandas(all_datasets[all_datasets['task'] == 'event extraction'])
hate = Dataset.from_pandas(all_datasets[all_datasets['task'] == 'hatespeech and toxicity'])
topic = Dataset.from_pandas(all_datasets[all_datasets['task'] == 'topic classification'])

task_ds = DatasetDict({'stance': stance, 'event': event, 'hate': hate, 'topic': topic})

# Row order of the output table.
splits = ['stance', 'topic', 'hate', 'event']
results = {split: describe_dataset(task_ds, split) for split in splits}

# Totals across the four tasks. NOTE: this rebinds `all_datasets` to the rows
# of these four tasks only; the sections further down read this rebound frame.
all_datasets = pd.concat([task_ds[split].to_pandas() for split in splits])
total_unique_datasets = all_datasets['dataset'].nunique()
total_documents = all_datasets.shape[0]
total_unique_hypotheses = all_datasets['hypothesis'].nunique()
# Count each label directly, exactly as the per-task columns below do
# (label 0 -> "Entail", label 1 -> "Not Entail"), so the Total row always
# agrees with the task rows instead of assuming the labels are strictly 0/1.
total_label_counts = all_datasets['entailment'].value_counts()
total_entail = total_label_counts.get(0, 0)
total_not_entail = total_label_counts.get(1, 0)

# Create dataframe to output a table (one row per task).
data_for_df = {
    'Unique Datasets': [results[split]['unique_datasets'] for split in splits],
    'Total Documents': [results[split]['total_premises'] for split in splits],
    'Unique Hypotheses': [results[split]['unique_hypotheses'] for split in splits],
    'Median Premise Length': [f"{results[split]['med_premise_word_count']:.2f}" for split in splits],
    'Entail': [results[split]['entailment_counts'].get(0, 0) for split in splits],
    'Not Entail': [results[split]['entailment_counts'].get(1, 0) for split in splits],
}
df = pd.DataFrame(data_for_df, index=splits)

# Add the totals row (median premise length over all four tasks combined).
df.loc['Total'] = [
    total_unique_datasets,
    total_documents,
    total_unique_hypotheses,
    f"{all_datasets['premise'].apply(lambda x: len(x.split())).median():.2f}",
    total_entail,
    total_not_entail,
]

# Save the DataFrame to a text file. Create ./tables first: open() would
# otherwise raise FileNotFoundError when the directory does not exist yet.
output_file = './tables/table_1.txt'
os.makedirs(os.path.dirname(output_file), exist_ok=True)
df_string = df.to_string()

with open(output_file, 'w', encoding='utf-8') as f:
    f.write(df_string)

##################################
## Section 3.3, Manual Validation
##################################
# Human validation of a sample of LLM-assigned labels. Columns read below:
# `agree` and `reasonable` are summed as 0/1 (or bool) flags; `entailment`
# and `validation` are compared with Cohen's kappa. NOTE(review): presumably
# `entailment` is the LLM label and `validation` the human label — confirm.
manual_val = pd.read_csv('./data/manual_validation.csv')
n_validated = manual_val.shape[0]
agree_count = manual_val['agree'].sum()
val_accuracy = agree_count / n_validated  # proportion in [0, 1]
disagree_count = n_validated - agree_count
reasonable_disagreement = manual_val['reasonable'].sum()
incorrect_count = disagree_count - reasonable_disagreement
kappa = round(cohen_kappa_score(manual_val['entailment'], manual_val['validation']), 2)

# `val_accuracy` is a proportion; format it as a percentage to match the label.
print(f"Human-LLM agreement percent: {val_accuracy:.2%}")
print(f"Human-LLM Cohen's Kappa: {kappa}")
print(f"Human-LLM disagree count: {disagree_count}")
print(f"Human-LLM reasonable disagreements: {reasonable_disagreement}")
print(f"LLM error count: {incorrect_count}")

########################################
## Section 3.4, Hypothesis Augmentation
########################################
# Distinct augmented hypotheses (missing values excluded) in `all_datasets`
# as it stands at this point in the script.
augmented_column = all_datasets['augmented_hypothesis']
augmented_count = augmented_column.nunique(dropna=True)
print(f"Total augmented hypotheses: {augmented_count}")

##########################
## Section 3.5, Data Split
##########################
print(f"Test set size: {test.shape[0]}")
print(f"Validation set size: {val.shape[0]}")
print(f"Training set size: {train.shape[0]}")

# Overlap of hypotheses between splits: each validation/test row is checked
# against the set of hypotheses that occur anywhere in the training split.
train_hypotheses = train['hypothesis']
val_seen_in_train = val['hypothesis'].isin(train_hypotheses)
val_hyp_in_train = int(val_seen_in_train.sum())
# isin() yields a strict boolean mask, so every remaining row is "not in train".
val_hyp_not_train = len(val) - val_hyp_in_train
test_hyp_in_train = int(test['hypothesis'].isin(train_hypotheses).sum())

print(f"Validation set documents with hypotheses in the training set: {val_hyp_in_train}")
print(f"Validation set documents with hypotheses not in the training set: {val_hyp_not_train}")
print(f"Test set documents with hypotheses in the training set: {test_hyp_in_train}")